{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\Anaconda3\\envs\\machine_learning\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Using cuda:0\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x26f0e7679b0>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "import h5py\n",
    "import copy\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "from torchvision.models import *\n",
    "from torchvision.transforms import *\n",
    "from torch.utils.data import DataLoader, Dataset\n",
    "from torch.optim import *\n",
    "from torch.optim.lr_scheduler import *\n",
    "\n",
    "import torch_pruning as tp\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from tqdm.auto import tqdm\n",
    "\n",
    "from Utils.utils import *\n",
    "from Utils.dataset import *\n",
    "from Utils.model_loading import *\n",
    "\n",
    "# Prefer the first CUDA device when one is available, otherwise fall back to the CPU\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device(\"cuda:0\")\n",
    "else:\n",
    "    device = torch.device(\"cpu\")\n",
    "print(f\"Using {device}\")\n",
    "\n",
    "# Seed the Python, NumPy and PyTorch RNGs with one shared value\n",
    "SEED = 0\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pruning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Square input resolution; the same value appears as the \"i152\" tag in the checkpoint filename below\n",
    "image_size = 152\n",
    "resize_transform = Resize((image_size, image_size), antialias=None)\n",
    "\n",
    "# Loaders built from the two HDF5 files; later cells index the result with \"train\" and \"test\"\n",
    "dataloader = build_dataloader(\n",
    "    train_path=\"Data/mitbih_mif_train_small.h5\",\n",
    "    test_path=\"Data/mitbih_mif_test.h5\",\n",
    "    batch_size=32,\n",
    "    transform=resize_transform,\n",
    ")\n",
    "\n",
    "# Baseline network that every pruning experiment below is deep-copied from\n",
    "model_name = \"vgg16_bn_custom\"\n",
    "base_model_path = \"Pretrained/vgg16_bn_custom_ecg_ep50_i152.pth\"\n",
    "num_classes = ArrhythmiaLabels.size\n",
    "base_model = load_model_from_pretrained(model_name, base_model_path, num_classes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[I]: Benchmarking model Pruned\n",
      "[I]: \tGetting model params and MACs\n",
      "[I]: \tMeasuring latency\n",
      "[I]: \tEvaluating\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                       "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[I]: Benchmarking for Pruned finished\n",
      "Name:     Pruned\n",
      "Accuracy: 82.76%\n",
      "Latency:  7.0 ms\n",
      "Params:   3 M\n",
      "MACs:     284 M\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "# Prune a CPU-resident deep copy (in eval mode) so base_model itself is left untouched\n",
    "pruned_model = copy.deepcopy(base_model).to(\"cpu\").eval()\n",
    "\n",
    "# Random example input of the model's input size, handed to the pruner\n",
    "dummy_input = torch.rand((1, 3, image_size, image_size))\n",
    "\n",
    "# The last classification layer is excluded from pruning\n",
    "ignored_layers = [pruned_model.classifier[6]]\n",
    "\n",
    "# Magnitude importance with p=2\n",
    "imp = tp.importance.MagnitudeImportance(p=2)\n",
    "\n",
    "# Non-global magnitude pruning with a 0.8 pruning ratio\n",
    "pruner = tp.pruner.MagnitudePruner(\n",
    "    model=pruned_model,\n",
    "    example_inputs=dummy_input,\n",
    "    importance=imp,\n",
    "    global_pruning=False,\n",
    "    pruning_ratio=0.8,\n",
    "    max_pruning_ratio=1.0,\n",
    "    ignored_layers=ignored_layers,\n",
    ")\n",
    "pruner.step()\n",
    "\n",
    "# Benchmark the pruned model on the test loader, before any finetuning\n",
    "pruned_model_stats = benchmark_model(pruned_model, dataloader[\"test\"], name=\"Pruned\")\n",
    "display_model_stats(pruned_model_stats)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Finetuning\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1 Accuracy 95.15% / Best Accuracy: 95.15%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 2 Accuracy 97.71% / Best Accuracy: 97.71%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 3 Accuracy 87.70% / Best Accuracy: 97.71%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 4 Accuracy 97.27% / Best Accuracy: 97.71%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 5 Accuracy 97.25% / Best Accuracy: 97.71%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 6 Accuracy 97.72% / Best Accuracy: 97.72%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 7 Accuracy 97.54% / Best Accuracy: 97.72%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 8 Accuracy 95.85% / Best Accuracy: 97.72%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 9 Accuracy 97.40% / Best Accuracy: 97.72%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 10 Accuracy 97.88% / Best Accuracy: 97.88%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 11 Accuracy 96.39% / Best Accuracy: 97.88%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 12 Accuracy 96.47% / Best Accuracy: 97.88%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 13 Accuracy 97.50% / Best Accuracy: 97.88%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 14 Accuracy 90.00% / Best Accuracy: 97.88%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 15 Accuracy 97.72% / Best Accuracy: 97.88%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 16 Accuracy 97.87% / Best Accuracy: 97.88%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 17 Accuracy 98.14% / Best Accuracy: 98.14%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 18 Accuracy 98.26% / Best Accuracy: 98.26%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 19 Accuracy 97.65% / Best Accuracy: 98.26%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 20 Accuracy 97.49% / Best Accuracy: 98.26%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 21 Accuracy 98.22% / Best Accuracy: 98.26%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 22 Accuracy 97.48% / Best Accuracy: 98.26%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 23 Accuracy 97.83% / Best Accuracy: 98.26%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 24 Accuracy 97.70% / Best Accuracy: 98.26%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 25 Accuracy 96.72% / Best Accuracy: 98.26%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 26 Accuracy 97.58% / Best Accuracy: 98.26%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 27 Accuracy 98.14% / Best Accuracy: 98.26%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 28 Accuracy 97.31% / Best Accuracy: 98.26%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        \r"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 29 Accuracy 97.59% / Best Accuracy: 98.26%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "                                                        "
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 30 Accuracy 96.86% / Best Accuracy: 98.26%\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r"
     ]
    }
   ],
   "source": [
    "# Finetune the pruned model, remember the weights of the best-scoring epoch, then save them\n",
    "num_finetune_epochs = 30\n",
    "save_dir = \"Pretrained/VGG-Pruned\"\n",
    "\n",
    "pruned_model.to(device)\n",
    "optimizer = SGD(pruned_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)\n",
    "# NOTE(review): T_max is the epoch count -- confirm train() steps the scheduler once per epoch\n",
    "scheduler = CosineAnnealingLR(optimizer, num_finetune_epochs)\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "# Start from the just-pruned weights so \"state_dict\" exists even if no epoch beats best_accuracy\n",
    "best_model_checkpoint = {\"state_dict\": copy.deepcopy(pruned_model.state_dict())}\n",
    "best_accuracy = 0\n",
    "\n",
    "print(\"Finetuning\")\n",
    "for epoch in range(num_finetune_epochs):\n",
    "    train(pruned_model, dataloader[\"train\"], criterion, optimizer, scheduler)\n",
    "    accuracy = evaluate(pruned_model, dataloader[\"test\"])\n",
    "\n",
    "    # Deep-copy the state_dict so the snapshot is not altered by further training\n",
    "    if accuracy > best_accuracy:\n",
    "        best_model_checkpoint[\"state_dict\"] = copy.deepcopy(pruned_model.state_dict())\n",
    "        best_accuracy = accuracy\n",
    "\n",
    "    print(f\"Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%\")\n",
    "\n",
    "# Save best model\n",
    "os.makedirs(save_dir, exist_ok=True)  # no error if the directory is already there\n",
    "\n",
    "# NOTE(review): only the state_dict is stored; after structural pruning it presumably loads only\n",
    "# into an identically pruned model -- confirm the consumer rebuilds that structure\n",
    "# The \"p0.8\" tag is hard-coded; keep it in sync with the pruning_ratio used in the pruning cell\n",
    "save_path = f\"{save_dir}/{model_name}_ecg_ep{num_finetune_epochs}_i{image_size}_p0.8.pth\"\n",
    "torch.save(best_model_checkpoint[\"state_dict\"], save_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sweep over global pruning ratios: prune a fresh copy of base_model, finetune it,\n",
    "# and save the model from the best-scoring epoch of each run\n",
    "ratios = [0.5, 0.6, 0.7, 0.8, 0.9]\n",
    "\n",
    "# Loop-invariant setup, done once instead of on every iteration\n",
    "num_finetune_epochs = 20\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "save_dir = \"Pretrained/VGG-Pruned\"\n",
    "os.makedirs(save_dir, exist_ok=True)  # no error if the directory is already there\n",
    "\n",
    "for ratio in ratios:\n",
    "    # The 0.9 run uses max_pruning_ratio 0.95; every other ratio is capped at 0.9\n",
    "    max_prune = 0.95 if ratio == 0.9 else 0.9\n",
    "\n",
    "    # Each run starts from its own CPU-resident copy of the baseline, in eval mode\n",
    "    pruned_model = copy.deepcopy(base_model).to(\"cpu\")\n",
    "    pruned_model.eval()\n",
    "\n",
    "    # Random example input of the model's input size, handed to the pruner\n",
    "    dummy_input = torch.rand((1, 3, image_size, image_size))\n",
    "\n",
    "    # The last classification layer is excluded from pruning\n",
    "    ignored_layers = [pruned_model.classifier[6]]\n",
    "\n",
    "    # Magnitude importance with p=2, applied with global pruning at the current ratio\n",
    "    imp = tp.importance.MagnitudeImportance(p=2)\n",
    "    pruner = tp.pruner.MagnitudePruner(\n",
    "        model=pruned_model,\n",
    "        example_inputs=dummy_input,\n",
    "        importance=imp,\n",
    "        global_pruning=True,\n",
    "        pruning_ratio=ratio,\n",
    "        max_pruning_ratio=max_prune,\n",
    "        ignored_layers=ignored_layers,\n",
    "    )\n",
    "    pruner.step()\n",
    "\n",
    "    # Finetuning: new optimizer and scheduler for every run\n",
    "    pruned_model.to(device)\n",
    "    optimizer = SGD(pruned_model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)\n",
    "    # NOTE(review): T_max is the epoch count -- confirm train() steps the scheduler once per epoch\n",
    "    scheduler = CosineAnnealingLR(optimizer, num_finetune_epochs)\n",
    "\n",
    "    # best_model starts as the un-finetuned pruned model, so there is always something to save\n",
    "    best_model = copy.deepcopy(pruned_model)\n",
    "    best_accuracy = 0\n",
    "\n",
    "    print(\"Finetuning\")\n",
    "    for epoch in range(num_finetune_epochs):\n",
    "        train(pruned_model, dataloader[\"train\"], criterion, optimizer, scheduler)\n",
    "        accuracy = evaluate(pruned_model, dataloader[\"test\"])\n",
    "\n",
    "        if accuracy > best_accuracy:\n",
    "            best_model = copy.deepcopy(pruned_model)\n",
    "            best_accuracy = accuracy\n",
    "\n",
    "        print(f\"Epoch {epoch+1} Accuracy {accuracy:.2f}% / Best Accuracy: {best_accuracy:.2f}%\")\n",
    "\n",
    "    # Save best model: the whole module object is saved here, not only a state_dict\n",
    "    save_path = f\"{save_dir}/{model_name}_ecg_ep{num_finetune_epochs}_i{image_size}_g{ratio}.pth\"\n",
    "    best_model.zero_grad()  # clear gradients before the module is serialized\n",
    "    torch.save(best_model, save_path)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "machine_learning",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
